{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import math\n",
    "from itertools import combinations, combinations_with_replacement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "print('Using gpu: %s ' % torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [Thinking like Transformers](https://arxiv.org/abs/2106.06981)\n",
    "\n",
    "Here we code our 'toy' GPT without any training in order to compute histograms. For the input sequence `<BOS>,a,a,b,a,b,c`, the output should be `0,3,3,2,3,2,1` as the letter `a` appears 3 times, the letter `b` 2 times and the letter `c` once. Each letter is replaced by its number of occurences (except `<BOS>` replaced by a `0`). \n",
    "\n",
    "## Self-Attention\n",
    "\n",
    "First start by coding your Self-Attention layer (do not worry about initialization for the moment)."
   ]
  },
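  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder, here is a sketch of the computation, assuming the standard scaled dot-product attention that the skeleton below appears to follow. The symbols $X$ (input), $W_Q$, $W_K$, $W_V$ (weights of the three linear layers) and $d_k$ (standing for `key_channels`) are notation introduced here for illustration. Note that the skeleton already folds the $1/\\sqrt{d_k}$ factor into `K`.\n",
    "\n",
    "$$Q = XW_Q^\\top, \\qquad K = \\frac{XW_K^\\top}{\\sqrt{d_k}}, \\qquad V = XW_V^\\top$$\n",
    "\n",
    "$$A = \\mathrm{softmax}\\left(QK^\\top\\right) \\in \\mathbb{R}^{T\\times T}, \\qquad y = AV$$\n",
    "\n",
    "The softmax is taken over the last dimension, so each row of $A$ sums to 1. No causal mask is assumed here, since each position must see the whole sequence in order to count occurrences."
   ]
  },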
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SelfAttentionLayer(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.n_channels = config.n_channels\n",
    "        self.key_channels = config.key_channels\n",
    "        self.Query = nn.Linear(self.n_channels, self.key_channels, bias=False)\n",
    "        self.Key = nn.Linear(self.n_channels, self.key_channels, bias = False)\n",
    "        self.Value = nn.Linear(self.n_channels, self.n_channels, bias = False)\n",
    "           \n",
    "    def _init_id(self):\n",
    "        self.Query.weight.data = 100*torch.eye(self.key_channels, self.n_channels)\n",
    "        self.Key.weight.data = 100*torch.eye(self.key_channels,self.n_channels)\n",
    "        self.Value.weight.data = torch.eye(self.key_channels,self.n_channels)        \n",
    "        \n",
    "    def forward(self, x): # x (bs, T, ic)\n",
    "        Q = self.Query(x) # (bs, T, kc)\n",
    "        K = self.Key(x)/math.sqrt(self.key_channels) # (bs, T, kc)\n",
    "        V = self.Value(x) # (bs, T, oc)\n",
    "        A = # your code here\n",
    "        y = # your code here\n",
    "        return y, A"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check your implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class toy_config:\n",
    "    n_channels = 3\n",
    "    key_channels = 3\n",
    "    \n",
    "sa_toy = SelfAttentionLayer(toy_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input = torch.randn(5,10,3)\n",
    "y,A = sa_toy(input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.sum(A, dim=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## identity GPT\n",
    "\n",
    "We first start with a simple example where we want to contruct the identity map. Clearly, in this case, we can just use the skip connections present in real transformer block. Instead, we will ignore these skip connections and use the self-attention layer. In this practical, we will ignore the layer norm.\n",
    "\n",
    "To make our life simpler, we encode `<BOS>` with a `0`, letter `a` with a `1` and so on...\n",
    "\n",
    "If we give as input the sequence `0,1,1,2,3,4,2,3,1`, we want to get the same sequence as output. This is clearly doable with a transformer block as follows:\n",
    "- take one-hot encoding of each token \n",
    "- take Query and Key matrices as `100*Id`\n",
    "- take Value matrix as `Id`\n",
    "As a result, the output of the self-attention layer will be the same as the input.\n",
    "\n",
    "Then take a Feed Forward Network which is simply the identity map as coded below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Block_id(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.attn = SelfAttentionLayer(config)\n",
    "        self.fake_mlp = (lambda x : x)\n",
    "        self.attn._init_id()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x, A = self.attn(x)\n",
    "        x = self.fake_mlp(x)\n",
    "        return x, A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_digits = 4\n",
    "class config:\n",
    "    n_channels=nb_digits+1\n",
    "    key_channels=nb_digits+1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bid = Block_id(config)\n",
    "one_sample = torch.tensor([[0.,0.,1.,0.,0.],[0.,1.,0.,0.,0.]]).unsqueeze(0)\n",
    "bid(one_sample)"
   ]
  },
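  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why this block acts as the identity, here is a short derivation. It assumes that the layer computes a row-wise softmax of $QK^\\top$; the notation $e_{x_i}$ (one-hot vector of token $x_i$), $d$ (standing for `key_channels`) and $c_i$ is introduced here for illustration. With one-hot inputs and the `_init_id` weights, the score between positions $i$ and $j$ is\n",
    "\n",
    "$$Q_i \\cdot K_j = \\frac{100^2}{\\sqrt{d}}\\, \\mathbb{1}[x_i = x_j].$$\n",
    "\n",
    "For $d=5$ this is about $4472$ when the two tokens match and $0$ otherwise, so the softmax puts numerically zero weight on non-matching tokens:\n",
    "\n",
    "$$A_{ij} \\approx \\frac{\\mathbb{1}[x_i = x_j]}{c_i}, \\qquad c_i = \\sum_j \\mathbb{1}[x_j = x_i].$$\n",
    "\n",
    "Since $V_j = e_{x_j}$, the output is $y_i = \\sum_j A_{ij}\\, e_{x_j} = e_{x_i}$: averaging $c_i$ identical one-hot vectors gives back the same vector."
   ]
  },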
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now to have really the identity map, we need to project back the one-hot encoding and this can be done with a linear layer (with good weights initialization)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPT_id(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.n_channels = config.n_channels\n",
    "        self.tok_emb = nn.Embedding(self.n_channels,self.n_channels)\n",
    "        self.block = Block_id(config)\n",
    "        self.head = nn.Linear(self.n_channels, 1, bias = False)\n",
    "        self._init_weights()\n",
    "        \n",
    "    def _init_weights(self):\n",
    "        #\n",
    "        # your code here\n",
    "        #\n",
    "        \n",
    "    def forward(self, idx):\n",
    "        x = self.tok_emb(idx)\n",
    "        x, A = self.block(x)\n",
    "        return self.head(x), A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gid = GPT_id(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "one_sample = torch.tensor([0,1,1,2,3,4,2,3,1]).unsqueeze(0)\n",
    "y, A = gid(one_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y == one_sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(A[0,:,:].cpu().data, cmap='hot', interpolation='nearest')\n",
    "plt.colorbar()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## histogram GPT\n",
    "\n",
    "Now we need to adapt previous case to code our 'toy' transformer block and your 'toy' GPT to compute histograms:\n",
    "- you will need to find a good initialization for the Quey, Key and Value matrices\n",
    "- for the feed forward network, you can fake the mlp with any function you'd like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Block_hist(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.attn = SelfAttentionLayer(config)\n",
    "        self.fake_mlp = # your code here\n",
    "        self.attn._init_hist() # this need to be coded in your self attention layer\n",
    "\n",
    "    def forward(self, x):\n",
    "        x, A = self.attn(x)\n",
    "        x = self.fake_mlp(x)\n",
    "        return x, A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPT_hist(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.n_channels = config.n_channels\n",
    "        self.tok_emb = nn.Embedding(self.n_channels,self.n_channels)\n",
    "        self.block = Block_hist(config)\n",
    "        self._init_weights()\n",
    "        \n",
    "    def _init_weights(self):\n",
    "        #\n",
    "        # your code here\n",
    "        #\n",
    "        \n",
    "        \n",
    "    def forward(self, idx):\n",
    "        x = self.tok_emb(idx)\n",
    "        x, A = self.block(x)\n",
    "        return x, A"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check your implementation by first choosing properly your configuration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gh = GPT_hist(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "one_sample = torch.tensor([0,1,1,2,3,4,2,3,1]).unsqueeze(0)\n",
    "y, A = gh(one_sample)\n",
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(A[0,:,:].cpu().data, cmap='hot', interpolation='nearest')\n",
    "plt.colorbar()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating your dataset\n",
    "\n",
    "Now, we will use a 'micro' GPT to learn the task of histograms. Before that, we will use our 'toy' GPT to generate the dataset. Since GPT is equivariant (a permutation of the input will permute the output), we can always take as input a sequence ordered. We can indeed compute all possible different inputs and this number is not too high. For a sequence of lenght `seq_train=s` with at most `nb_digits=n`, there are ${s+n-1 \\choose n-1}$ possibilities. Now for each such sequence, we pass it through our toy GPT to get the label."
   ]
  },
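  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of this count, here is a worked instance with the values used in the next cell, $s=30$ and $n=4$ (the counts $k_1,\\dots,k_n$ are notation introduced here). A sorted sequence is determined by how many times each of the $n$ digits appears, that is, by non-negative integers $k_1,\\dots,k_n$ with $k_1+\\dots+k_n=s$. By the stars-and-bars argument, the number of such tuples is\n",
    "\n",
    "$${s+n-1 \\choose n-1} = {33 \\choose 3} = \\frac{33\\cdot 32\\cdot 31}{3!} = 5456.$$\n",
    "\n",
    "Equivalently, one chooses $n-1$ cut points among the integers from $0$ to $s$, with repetition allowed, and reads the counts as the gaps between consecutive cut points. This is what `combinations_with_replacement` and `make_seq` do below."
   ]
  },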
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_train = 30\n",
    "nb_digits = 4\n",
    "comb = combinations_with_replacement(range(0,seq_train+1), nb_digits-1)\n",
    "\n",
    "def make_seq(c, seq_train):\n",
    "    c_l = [0] + list(c) + [seq_train]\n",
    "    len_seq = len(c_l)-1\n",
    "    return [c_l[i+1]-c_l[i] for i in range(len_seq)]\n",
    "\n",
    "l_comb =  [make_seq(c,seq_train) for c in comb]\n",
    "\n",
    "len(l_comb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "math.comb(seq_train+nb_digits-1, nb_digits-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_inputs(l_comb, nb_digits=nb_digits):\n",
    "    inputs = []\n",
    "    for t in l_comb:\n",
    "        curr = [0]\n",
    "        for (i,j) in enumerate(t):\n",
    "            curr += [i+1 for _ in range(j)]\n",
    "        inputs.append(torch.tensor(np.array(curr)))\n",
    "    return inputs\n",
    "\n",
    "def make_loader(len_seq,nb_digits):\n",
    "    comb = combinations_with_replacement(range(0,len_seq+1), nb_digits-1)\n",
    "    l_comb =  [make_seq(c,len_seq) for c in comb]\n",
    "    inputs = make_inputs(l_comb)\n",
    "    labels = [(gh(d.unsqueeze(0))[0].squeeze(0).squeeze(1)).type(torch.LongTensor) for d in inputs]\n",
    "    dataset = list(zip(inputs,labels))\n",
    "    len_in = len(dataset)\n",
    "    loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)\n",
    "    return loader, len_in, inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, size_train, inputs_train = make_loader(seq_train,nb_digits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_in = next(iter(train_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_in[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_in[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_in[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_in[1][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Coding 'micro' GPT\n",
    "\n",
    "Now we need to code the 'micro' GPT used for learning. The game here is to reuse our `SelfAttentionLayer` above without any modification. The only part that is modified is the hard-coded `fake_mlp` which is now replaced by a real MLP."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Block(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.attn = SelfAttentionLayer(config)\n",
    "        self.mlp = # your code here\n",
    "\n",
    "    def forward(self, x, verbose=False): # x (bs, T,ic)\n",
    "        #\n",
    "        # your code here\n",
    "        #\n",
    "        if verbose:\n",
    "            return x, A\n",
    "        else:\n",
    "            return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPT(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.n_channels = config.n_channels\n",
    "        self.nb_digits = config.nb_digits\n",
    "        self.tok_emb = # your code here \n",
    "        self.block = Block(config)\n",
    "        self.head = # your code here \n",
    "        \n",
    "    def forward(self, idx, targets=None, verbose=False):\n",
    "        # shape of idx: (bs, len) 0=bos and 1...nb_digits\n",
    "        # shape of targets: (bs, len)\n",
    "        #\n",
    "        # your code here\n",
    "        #\n",
    "        \n",
    "        loss = None\n",
    "        if targets is not None:\n",
    "            loss = # your code here\n",
    "        if verbose:\n",
    "            return logits, loss, A\n",
    "        else:\n",
    "            return logits, loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class config_gpt:\n",
    "    nb_digits = nb_digits\n",
    "    n_channels = 32 \n",
    "    key_channels = 64 \n",
    "    max_hist = seq_train+1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gptmini = GPT(config_gpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logits, _ = gptmini(batch_in[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_,preds = torch.max(logits,-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_in[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.sum(preds == batch_in[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, dataloader, size, epochs=1, optimizer=None):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        running_loss = 0.0\n",
    "        running_corrects = 0\n",
    "        n_batch = 0\n",
    "        for inputs,targets in dataloader:\n",
    "            inputs = inputs.to(device)\n",
    "            targets = targets.to(device)\n",
    "            logits, loss = model(inputs,targets)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            _,preds = torch.max(logits,-1)\n",
    "           \n",
    "            running_corrects += torch.true_divide(torch.sum(preds == targets.data),targets.shape[0]*targets.shape[1])\n",
    "            running_loss +=  loss.data.item()\n",
    "            n_batch += 1\n",
    "        epoch_loss = running_loss /n_batch\n",
    "        epoch_acc = running_corrects.data.item() /n_batch\n",
    "        print('Loss: {:.4f} Acc: {:.4f}'.format(\n",
    "                     epoch_loss, epoch_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gptmini = GPT(config_gpt)\n",
    "gptmini = gptmini.to(device)\n",
    "lr = 0.01\n",
    "optimizer = torch.optim.Adam(gptmini.parameters(),lr = lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len_train = (seq_train+1)*size_train\n",
    "train_model(gptmini,train_loader,size_train,15,optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 0.005\n",
    "optimizer = torch.optim.Adam(gptmini.parameters(),lr = lr)\n",
    "train_model(gptmini,train_loader,len_train,15,optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 0.001\n",
    "optimizer = torch.optim.Adam(gptmini.parameters(),lr = lr)\n",
    "train_model(gptmini,train_loader,len_train,15,optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 0.0001\n",
    "optimizer = torch.optim.Adam(gptmini.parameters(),lr = lr)\n",
    "train_model(gptmini,train_loader,len_train,15,optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "one_batch = batch_in[0].to(device)\n",
    "logits, loss, A = gptmini(one_batch,verbose=True)\n",
    "A.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 45\n",
    "plt.imshow(A[k,:,:].cpu().data, cmap='hot', interpolation='nearest')\n",
    "plt.colorbar()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dldiy",
   "language": "python",
   "name": "dldiy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
